Algorithms 2: Gradient Descent and MCMC

林嶔 (Lin, Chin)

Lesson 18

Gradient Descent (1)

– By now we have seen one thing: no matter what form the function takes, the point of maximum likelihood estimation is that it lets us define an objective function to solve. So let us simplify the problem: we want a method that can find an extremum of an arbitrary function.

– We will use the following simple example to introduce "gradient descent".

Gradient Descent (2)

– Still hard to grasp? Let's think through solving y = x^2. We know the derivative of y = x^2 is 2x, which means the slope of the tangent line at any point is 2x, and the slope is "the amount y changes when x increases by one unit".

– With this in mind, if we want the minimum of y = x^2, we can pick a random starting value of x and repeatedly move the point against the tangent slope, i.e. x.new = x.old - learning.rate * 2 * x.old. This will lead us to the minimum.

– If that is still too abstract, let's implement the process in R:

original.fun = function(x) {      # the function to minimize: y = x^2
  return(x^2)
}

differential.fun = function(x) {  # its derivative: dy/dx = 2x
  return(2*x)
}

x = seq(-6, 6, by = 0.01)
y = original.fun(x)

start.value = 5
learning.rate = 0.1
num.iteration = 20

result.x = rep(NA, num.iteration)
result.y = rep(NA, num.iteration)

par(mfcol = c(4, 5))

for (i in 1:num.iteration) {
  if (i == 1) {
    result.x[1] = start.value
    result.y[1] = original.fun(start.value)
  } else {
    result.x[i] = result.x[i-1] - learning.rate * differential.fun(result.x[i-1])
    result.y[i] = original.fun(result.x[i])
  }
  plot(x, y, xlim = c(-5, 5), ylim = c(0, 25), type = "l", main = paste0("iteration = ", i))
  col.points = rep("black", num.iteration)
  col.points[i] = "red"
  points(result.x, result.y, pch = 19, col = col.points)
}

[Figure F18_1: the 20 gradient-descent iterations plotted by the code above]

Gradient Descent (3)

[Figure F18_2]

– When using gradient descent, learning.rate should in principle not be set too large; observe the convergence speed, and adjust it upward moderately only if convergence is too slow.
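
– As a quick illustration, here is a minimal sketch (reusing the differential.fun defined above; the specific rates are arbitrary choices) that runs the same descent with several learning rates. A rate that is too small makes progress slowly, while a rate that is too large overshoots the minimum and can even diverge.

start.value = 5

for (lr in c(0.01, 0.1, 0.5, 1.1)) {
  x.now = start.value
  for (i in 1:20) {
    x.now = x.now - lr * differential.fun(x.now)  # the usual update rule
  }
  # lr = 0.01 is still far from 0 after 20 steps; lr = 1.1 makes |x| grow at every step
  cat("learning.rate =", lr, "-> x after 20 iterations =", x.now, "\n")
}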

Exercise 1

– Try using gradient descent to fit a simple linear regression, i.e. to minimize

$$\text{loss}(b_0, b_1) = \frac{1}{2n}\sum_{i=1}^{n}(\hat{y}_i - y_i)^2, \qquad \hat{y}_i = b_0 + b_1 x_i$$

– Differentiating it gives the following equations:

$$\frac{\partial\,\text{loss}}{\partial b_0} = \frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i)$$

$$\frac{\partial\,\text{loss}}{\partial b_1} = \frac{1}{n}\sum_{i=1}^{n}(\hat{y}_i - y_i)\,x_i$$

x = c(1, 2, 3, 4, 5)
y = c(6, 7, 9, 8, 10)

original.fun = function(b0, b1, x = x, y = y) {
  y.hat = b0 + b1 * x
  return(sum((y.hat - y)^2)/2/length(x))  # sum of squared errors over 2n
}

differential.fun.b0 = function(b0, b1, x = x, y = y) {
  y.hat = b0 + b1 * x
  return(sum(y.hat - y)/length(x))
}

differential.fun.b1 = function(b0, b1, x = x, y = y) {
  y.hat = b0 + b1 * x
  return(sum((y.hat - y)*x)/length(x))
}

differential.fun.b0(b0 = 1, b1 = 0.1, x = x, y = y)
## [1] -6.7
differential.fun.b1(b0 = 3, b1 = 0.7, x = x, y = y)
## [1] -9.1
model = lm(y~x)
print(model)
## 
## Call:
## lm(formula = y ~ x)
## 
## Coefficients:
## (Intercept)            x  
##         5.3          0.9
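
– One possible solution is sketched below (the choices of learning.rate and num.iteration are arbitrary, as long as the coefficients have stabilized): iterate the two update rules, starting from b0 = b1 = 0, and compare the result with lm().

b0 = 0
b1 = 0
learning.rate = 0.1
num.iteration = 2000

for (i in 1:num.iteration) {
  grad.b0 = differential.fun.b0(b0 = b0, b1 = b1, x = x, y = y)
  grad.b1 = differential.fun.b1(b0 = b0, b1 = b1, x = x, y = y)
  b0 = b0 - learning.rate * grad.b0  # move both coefficients against the gradient
  b1 = b1 - learning.rate * grad.b1
}

c(b0 = b0, b1 = b1)  # should approach lm()'s coefficients: 5.3 and 0.9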

Markov Chain Monte Carlo (1)

– Moreover, gradient descent requires calculus. Is there a method that does not need differentiation at all?

– The idea behind MCMC is very similar to gradient descent; the only difference is that gradient descent moves against the "gradient", while MCMC moves completely at random.

– In addition, precisely because MCMC moves completely at random, once it has wandered into the neighborhood of an extremum we can exploit that random movement to characterize the distribution around it.

Markov Chain Monte Carlo (2)

original.fun = function(x) {  # the function to minimize, as before
  return(x^2)
}

random.walk = function(x) {   # propose a small random step around x
  x + runif(1, min = -0.1, max = 0.1)
}

set.seed(0)

start.value = 5
num.iteration = 1000

x = rep(NA, num.iteration)

for (i in 1:100) {
  if (i == 1) {
    x[i] = start.value
  } else {
    old.x = x[i-1]
    new.x = random.walk(old.x)
    if (original.fun(old.x) < original.fun(new.x)) {
      x[i] = old.x   # reject any proposal that would increase the function value
    } else {
      x[i] = new.x
    }
  }
}

x[71:100]
##  [1] 3.482776 3.482776 3.450591 3.450591 3.419928 3.386683 3.381953
##  [8] 3.381953 3.381953 3.359951 3.359951 3.359951 3.346883 3.346883
## [15] 3.326882 3.291952 3.291952 3.232490 3.232490 3.156829 3.105927
## [22] 3.034587 2.982513 2.894300 2.894300 2.894300 2.894300 2.894300
## [29] 2.885355 2.867372
for (i in 101:1000) {
  old.x = x[i-1]
  new.x = random.walk(old.x)
  if (original.fun(old.x) < original.fun(new.x)) {
    x[i] = old.x
  } else {
    x[i] = new.x
  }
}

x[971:1000]
##  [1] -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651
##  [6] -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651
## [11] -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651
## [16] -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651
## [21] -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651
## [26] -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651 -0.0001650651

Markov Chain Monte Carlo (3)

– First, let's write out the likelihood function:

set.seed(0)

x = rnorm(1000)
y = 5 + 3 * x + rnorm(1000)

prop.fun = function(b0, b1, x = x, y = y) {
  y.hat = b0 + b1 * x
  res = y.hat - y
  log_p = dnorm(res, mean = 0, sd = sd(res), log = TRUE)  # normal log-density of each residual
  return(sum(log_p))  # log-likelihood of the whole sample
}
start.b0 = 0
start.b1 = 0
num.iteration = 10000

b0.seq = rep(NA, num.iteration)
b1.seq = rep(NA, num.iteration)

for (i in 1:num.iteration) {
  if (i == 1) {
    b0.seq[i] = start.b0
    b1.seq[i] = start.b1
  } else {
    b0.seq[i] = random.walk(b0.seq[i-1])
    b1.seq[i] = random.walk(b1.seq[i-1])
    old.log_p = prop.fun(b0 = b0.seq[i-1], b1 = b1.seq[i-1], x = x, y = y)
    new.log_p = prop.fun(b0 = b0.seq[i], b1 = b1.seq[i], x = x, y = y)
    diff.p = exp(new.log_p - old.log_p)  # likelihood ratio: proposal vs. current state
    if (diff.p < runif(1, min = 0, max = 1)) {  # i.e. accept with probability min(1, diff.p)
      b0.seq[i] = b0.seq[i-1] 
      b1.seq[i] = b1.seq[i-1] 
    }
  }
}

Markov Chain Monte Carlo (4)

par(mfcol = c(1, 2))
plot(1:num.iteration, b0.seq, type = "l")
plot(1:num.iteration, b1.seq, type = "l")

– Let's set a burn-in period and look only at the samples drawn after the chain has started oscillating around the mode:

burn_in = 5000

par(mfcol = c(2, 2))
hist(b0.seq[(burn_in+1):num.iteration])
abline(v = mean(b0.seq[(burn_in+1):num.iteration]), col = "blue")
abline(v = 5, col = "red")
plot((burn_in+1):num.iteration, b0.seq[(burn_in+1):num.iteration], type = "l")
abline(h = mean(b0.seq[(burn_in+1):num.iteration]), col = "blue")
abline(h = 5, col = "red")
hist(b1.seq[(burn_in+1):num.iteration])
abline(v = mean(b1.seq[(burn_in+1):num.iteration]), col = "blue")
abline(v = 3, col = "red")
plot((burn_in+1):num.iteration, b1.seq[(burn_in+1):num.iteration], type = "l")
abline(h = mean(b1.seq[(burn_in+1):num.iteration]), col = "blue")
abline(h = 3, col = "red")

fit = lm(y~x)
summary(fit)
## 
## Call:
## lm(formula = y ~ x)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -3.0180 -0.7024 -0.0008  0.7504  3.0451 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  4.97501    0.03271   152.1   <2e-16 ***
## x            2.98693    0.03279    91.1   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1.034 on 998 degrees of freedom
## Multiple R-squared:  0.8927, Adjusted R-squared:  0.8926 
## F-statistic:  8300 on 1 and 998 DF,  p-value: < 2.2e-16
mean(b0.seq[(burn_in+1):num.iteration]) # Answer ~ 4.975
## [1] 4.975639
mean(b1.seq[(burn_in+1):num.iteration]) # Answer ~ 2.987
## [1] 2.987548
sd(b0.seq[(burn_in+1):num.iteration]) # Answer ~ 0.0327
## [1] 0.03229219
sd(b1.seq[(burn_in+1):num.iteration]) # Answer ~ 0.0328
## [1] 0.03296865

Exercise 2

– Can you try to obtain the logistic regression estimates using MCMC?

set.seed(0)

x = rnorm(1000)
lr = -1 + 1 * x
prop = 1/(1+exp(-lr))
y = as.numeric(prop > runif(1000))

prop.fun = function(b0, b1, x = x, y = y) {
  LR = b0 + b1 * x                          # linear predictor
  PROP = 1/(1+exp(-LR))                     # predicted probability (inverse logit)
  log_p = y*log(PROP) + (1-y)*log(1-PROP)   # Bernoulli log-likelihood per observation
  return(sum(log_p))
}

prop.fun(b0 = -1, b1 = 1, x = x, y = y)
## [1] -540.8725
prop.fun(b0 = 1, b1 = 0.5, x = x, y = y)
## [1] -932.6171
glm(y~x, family = "binomial")
## 
## Call:  glm(formula = y ~ x, family = "binomial")
## 
## Coefficients:
## (Intercept)            x  
##     -0.9258       0.9707  
## 
## Degrees of Freedom: 999 Total (i.e. Null);  998 Residual
## Null Deviance:       1241 
## Residual Deviance: 1081  AIC: 1085
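
– One possible solution is sketched below, mirroring the sampler from the previous section (the starting values, step size, and burn-in are arbitrary choices), with random.walk() as defined earlier and the logistic prop.fun() above:

start.b0 = 0
start.b1 = 0
num.iteration = 10000
burn_in = 5000

b0.seq = rep(NA, num.iteration)
b1.seq = rep(NA, num.iteration)

for (i in 1:num.iteration) {
  if (i == 1) {
    b0.seq[i] = start.b0
    b1.seq[i] = start.b1
  } else {
    b0.seq[i] = random.walk(b0.seq[i-1])
    b1.seq[i] = random.walk(b1.seq[i-1])
    old.log_p = prop.fun(b0 = b0.seq[i-1], b1 = b1.seq[i-1], x = x, y = y)
    new.log_p = prop.fun(b0 = b0.seq[i], b1 = b1.seq[i], x = x, y = y)
    if (exp(new.log_p - old.log_p) < runif(1)) {  # Metropolis acceptance, as before
      b0.seq[i] = b0.seq[i-1]
      b1.seq[i] = b1.seq[i-1]
    }
  }
}

mean(b0.seq[(burn_in+1):num.iteration])  # should be close to glm()'s -0.9258
mean(b1.seq[(burn_in+1):num.iteration])  # should be close to glm()'s 0.9707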

The Local Extremum Problem (1)

– As usual, let's start with a simple problem: we want the extrema of y = 2x^2 + 3x^3 + x^4, but let's look at the graph of the function first.

[Figure F18_6: the graph of y = 2x^2 + 3x^3 + x^4, which has two local minima]

differential.fun = function(x) {  # derivative of y = 2x^2 + 3x^3 + x^4
  return(4*x + 9*x^2 + 4*x^3)
}

start.value = 0.5 
num.iteration = 300
learning.rate = 0.01

result.x = rep(NA, num.iteration)

for (i in 1:num.iteration) {
  if (i == 1) {
    result.x[1] = start.value
  } else {
    result.x[i] = result.x[i-1] - learning.rate * differential.fun(result.x[i-1])
  }
}

tail(result.x, 1)
## [1] 1.08942e-06
start.value = -3 
num.iteration = 300
learning.rate = 0.01

result.x = rep(NA, num.iteration)

for (i in 1:num.iteration) {
  if (i == 1) {
    result.x[1] = start.value
  } else {
    result.x[i] = result.x[i-1] - learning.rate * differential.fun(result.x[i-1])
  }
}

tail(result.x, 1)
## [1] -1.640388

The Local Extremum Problem (2)

– Try different starting values yourself in the linear/logistic regressions above: you will find that neither suffers from this problem, because their objective functions look like y = x^2, that is, they are convex.

– Functions like y = 2x^2 + 3x^3 + x^4, by contrast, are non-convex.

  1. An exact analytical solution always yields the global minimum.

  2. Approximate methods such as gradient descent/MCMC are only guaranteed to find a local minimum.

  3. For non-convex functions, which local minimum is found depends strongly on the starting value (see the sketch after this list).

  4. For convex functions, the starting value has no effect on the minimum that is found.
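
– A common workaround, shown here as a minimal sketch (the restart count and starting range are arbitrary, and this remedy goes slightly beyond the lesson text), is to run gradient descent from several random starting values and keep the best result:

original.fun = function(x) {  # the non-convex example function
  return(2*x^2 + 3*x^3 + x^4)
}

set.seed(0)
num.restart = 10
candidate.x = rep(NA, num.restart)

for (j in 1:num.restart) {
  x.now = runif(1, min = -4, max = 2)  # a random starting value
  for (i in 1:300) {
    x.now = x.now - 0.01 * differential.fun(x.now)
  }
  candidate.x[j] = x.now  # the local minimum reached from this start
}

candidate.x[which.min(original.fun(candidate.x))]  # the best local minimum found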

The Local Extremum Problem (3)

– From the perspective of statistical testing, the goal is to test whether a parameter differs from 0, so we can let the search start at 0 (exactly what the solution sketch in Exercise 1 did): if it converges near 0, the result is not significant; if it converges somewhere else, then regardless of whether that answer is the global optimum, it already suffices to show that the null hypothesis is false!

– From the perspective of predicting unknown outcomes, the equation obtained from a local optimum, even if it is not the one with the best predictive power, still predicts reasonably well, so it remains practically useful.

Algorithm Summary

– Although the mathematical derivations behind statistical methods are not a prerequisite for using them, understanding the underlying principles makes it much easier to grasp each method's range of application and its limitations.

– These two weeks of lessons have also given you a first look at how statistical methods are developed: if you discover the true relationship between X and Y, try writing it down and solving it with the algorithms you have learned over these two weeks!

– Are you eager to develop your own statistical methods yet?

– Most of the statistical methods we learn in biostatistics are linear, and their likelihood functions are convex; the machine learning methods that come later are mostly nonlinear, and their objective functions are non-convex, which leads to many problems in practice that ordinary statistical methods never encounter.